
import torch

from replay.batch import Batch
from replay.oegn import Oegn
from tools.utils import preprocess

class GngManager:
    """Feed rollout transitions and goal-interaction data into an ``Oegn``.

    Only the first entry (index 0) of every batched rollout argument is used.
    For it the manager keeps the previous observation, the previous state
    (when ``args.state`` is set) and the number of ``insert`` calls since the
    last ``init()``/``change_goal()``. ``can_learn``, ``sample``, ``load`` and
    ``save`` are delegated to ``self.oegn``.

    Attributes of ``args`` read by this class: ``state``, ``batch_size``,
    ``num_som_updates``, ``target_prob`` and ``delete_ins``.
    """

    def __init__(self, args, obs_shape, action_space, goal_space, context, save_dir, oegn=None, **kwargs):
        """Create the manager.

        Args:
            args: configuration namespace (see the class docstring).
            obs_shape: shape of one observation; ``prev_obs`` has shape
                ``(1, *obs_shape)``.
            action_space: forwarded to ``Oegn`` and ``Batch``.
            goal_space: forwarded to ``Oegn``.
            context: shared object; this class reads ``context.estimator``,
                ``context.coord_actor`` and ``context.memory2``.
            save_dir: forwarded to ``Oegn``.
            oegn: existing ``Oegn`` to wrap; a new one is built when ``None``.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        self.args = args
        self.context = context
        self.obs_shape = obs_shape
        self.action_space = action_space
        self.goal_space = goal_space
        self.save_dir = save_dir

        if oegn is None:
            # Oegn signature: (steps_init, args, obs_size, action_space, goal_space, context, save_dir=None)
            self.oegn = Oegn(300, self.args, self.obs_shape, self.action_space, self.goal_space, self.context,
                             save_dir=save_dir)
        else:
            self.oegn = oegn

        # Per-rollout bookkeeping for environment 0.
        self.actual_goal_steps = torch.zeros((1,), dtype=torch.long)  # insert() calls since init()/change_goal()
        self.prev_obs = torch.zeros((1, *self.obs_shape))
        self.prev_state = torch.zeros((1, 2)) if self.args.state else None
        self.batch = Batch(args, obs_shape, action_space)

        self.max_goalstep = 0  # not read by any method of this class

    def init(self, obs, state=None):
        """Start a new rollout from ``obs`` (and ``state`` when ``args.state``).

        Copies ``obs``/``state`` into the stored previous values and resets the
        goal step counter to 0. Returns ``obs`` unchanged.
        """
        self.prev_obs.copy_(obs)
        self.actual_goal_steps.zero_()
        if self.args.state:
            self.prev_state.copy_(state)
        return obs

    def can_learn(self):
        """Return ``self.oegn.can_learn()``."""
        return self.oegn.can_learn()

    def get_evals(self):
        """Return the ``Batch`` object owned by this manager."""
        return self.batch

    def sample(self, **kwargs):
        """Return ``self.oegn.sample(self.batch, **kwargs)``."""
        return self.oegn.sample(self.batch, **kwargs)

    def insert(self, obs, infos, actions, masks, reward, *args, goal=None, features=None, state=None, act_state=None,
               **kwargs):
        """Forward the transition of environment 0 to ``self.oegn.insert``.

        Args:
            obs: batched observations; ``obs[0:1]`` is the next observation
                unless ``infos[0]["true_obs"]`` is set. ``obs`` also becomes
                the stored previous observation.
            infos: per-environment info dicts. When ``infos[0]["true_obs"]``
                is not ``None`` it (and ``infos[0]["true_state"]`` when
                ``args.state``) replaces the next observation/state.
                Presumably this is the real final observation of an episode
                that was auto-reset (TODO confirm).
            actions, masks, reward: batched; only index 0 is forwarded.
            *args: forwarded positionally to ``oegn.insert``.
            goal: batched goals; ``goal[0:1]`` is forwarded, or ``None`` when
                no goal is given.
            features: precomputed embeddings. ``features[0:1]`` is used when
                ``true_obs`` is ``None``; otherwise the embedding is computed
                with ``context.estimator.goal_embed`` (without gradients).
            state: current state; stored as the next ``prev_state`` when
                ``args.state``.
            act_state: forwarded unchanged.
            **kwargs: ignored.
        """
        true_obs = infos[0]["true_obs"]
        # Take copies (not views of the caller's tensors) of the next observation.
        if true_obs is None:
            next_obs = obs[0:1].clone()
            statet = state
        else:
            next_obs = true_obs.unsqueeze(0).clone()
            statet = torch.tensor(infos[0]["true_state"]) if self.args.state else None

        if true_obs is not None or features is None:
            with torch.no_grad():
                f = self.context.estimator.goal_embed(preprocess(next_obs, self.args), act=True).cpu()
        else:
            f = features[0:1].cpu()

        self.oegn.insert(f, self.prev_obs[0], next_obs[0], actions[0], masks[0], reward[0], *args,
                         goal=goal[0:1] if goal is not None else None,
                         goal_step=self.actual_goal_steps[0].item(),
                         goal_obs=self.context.coord_actor.goals_obs[0],
                         state=statet, prev_state=self.prev_state, act_state=act_state)

        self.actual_goal_steps.add_(1)
        self.prev_obs[:] = obs
        if self.args.state:
            self.prev_state[:] = state

    def insert_goal_data(self, batch, masks, m_pb):
        """Replay masked goal-interaction samples through ``self.oegn.step``.

        Each ``batch`` field is restricted to its first ``masks.shape[0]``
        rows and indexed with ``masks`` (presumably a boolean mask). Then
        ``min(m_pb.shape[0], args.num_som_updates)`` rows are drawn uniformly
        with replacement and fed one at a time to ``self.oegn.step``, using
        embeddings, or states when ``args.state``. For each drawn row:

        * if ``step`` returns a truthy node and ``args.delete_ins`` is set,
          the transition is inserted into that node's buffer;
        * if ``oegn.has_deleted`` is false, the row's distribution id is
          truthy, and its node buffer is not deleted, the row's entry in
          ``context.memory2.embeds`` is overwritten with the current embedding
          and the buffer's ``maj_select_index`` is called.

        Does nothing when no row survives the mask.
        """
        n_rows = masks.shape[0]

        def _masked(t):
            # Keep the first n_rows rows of a batch field, then the masked ones.
            return t[:n_rows][masks]

        embeddings_masked = _masked(batch.som_embeddings).cpu()
        prev_embeddings_masked = _masked(batch.som_prev_embeddings).cpu()

        states = prev_states = None
        if self.args.state:
            states = _masked(batch.states).cpu()
            prev_states = _masked(batch.prev_states).cpu()

        nodes = _masked(batch.ind)
        distids = _masked(batch.distribution_id)
        indexes = _masked(batch.index)

        n_available = embeddings_masked.shape[0]
        if n_available == 0:
            # torch.randint(0, 0, ...) would raise; there is nothing to replay.
            return

        n_updates = min(m_pb.shape[0], self.args.num_som_updates)
        numbers = torch.randint(0, n_available, (n_updates,))

        # We sequentially feed the oegn algorithm.
        for i in numbers.tolist():
            # target_prob >= 20 makes every update count as a success.
            # NOTE(review): m_pb is indexed with a masked-row index, so it is
            # assumed to be aligned with the masked rows — confirm with caller.
            success = self.args.target_prob >= 20 or (m_pb[i] < self.args.target_prob)

            if self.args.state:
                insert_node, s_1 = self.oegn.step(states[i:i + 1], nodes[i].item(), success,
                                                  prev_states[i:i + 1], distids[i])
            else:
                insert_node, s_1 = self.oegn.step(embeddings_masked[i].view(1, -1), nodes[i].item(), success,
                                                  prev_embeddings_masked[i], distids[i])

            if insert_node and self.args.delete_ins:
                self.oegn.buffers[insert_node].insert(
                    _masked(batch.obs)[i],
                    _masked(batch.next_obs)[i],
                    _masked(batch.actions)[i],
                    _masked(batch.masks)[i],
                    _masked(batch.rewards)[i],
                    goal=_masked(batch.goals)[i],
                    goal_step=(_masked(batch.goals_step)[i] - 1).item(),
                    goal_obs=_masked(batch.goals_obs)[i],
                    embed=embeddings_masked[i].view(1, -1),
                    state=states[i] if self.args.state else None,
                    prev_state=prev_states[i] if self.args.state else None,
                    mode="embed",
                )

            # We update the stored embedding of the state of the interaction.
            node_item = nodes[i].item()
            if not self.oegn.has_deleted and distids[i] and not self.oegn.buffers[node_item].is_deleted():
                index_item = indexes[i].item()
                node_buffer = self.oegn.buffers[node_item]
                sid = node_buffer.learnDataStore.sid
                self.context.memory2.embeds[sid, index_item, :] = embeddings_masked[i, :]
                node_buffer.maj_select_index(index_item, s_1 == nodes[i])

    def compute_returns(self, values, gamma):
        """Return ``batch.irewards + gamma * values`` as a column tensor.

        ``values`` is reshaped to ``(args.batch_size, -1)`` before the sum.
        When ``values`` is ``None``, ``batch.irewards`` alone is returned as a
        column.
        """
        batch = self.get_evals()
        if values is None:
            return batch.irewards.view(-1, 1)

        next_values = values.view(self.args.batch_size, -1)
        returns = batch.irewards + gamma * next_values
        return returns.view(-1, 1)

    def after_update(self):
        """Zero ``self.batch.irewards`` in place."""
        self.batch.irewards.zero_()

    def change_goal(self):
        """Reset the goal step counter to 0."""
        self.actual_goal_steps[:] = 0

    def load(self):
        """Delegate to ``self.oegn.load()``."""
        self.oegn.load()

    def save(self):
        """Delegate to ``self.oegn.save()``."""
        self.oegn.save()

    def clone(self, **kwargs):
        """Return a new manager with the same configuration and a fresh ``Oegn``.

        The current ``Oegn`` is neither shared nor copied. ``kwargs`` is ignored.
        """
        return GngManager(self.args, self.obs_shape, self.action_space, self.goal_space, self.context,
                          self.save_dir, oegn=None)
